{
 "cells": [
  
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "\n",
    "# Select a fixed random seed for reproducibility\n",
    "mx.random.seed(42)\n",
    "\n",
    "def data_xform(data):\n",
    "    \"\"\"Move channel axis to the beginning, cast to float32, and normalize to [0, 1].\n",
    "\n",
    "    Axis 2 of ``data`` is moved to position 0, the result is cast to\n",
    "    ``float32`` and divided by 255.\n",
    "    \"\"\"\n",
    "    # Use the fully qualified `mx.nd` so this cell only depends on the\n",
    "    # `mxnet` import above and not on a `from mxnet import nd` that is\n",
    "    # executed in a later cell.\n",
    "    return mx.nd.moveaxis(data, 2, 0).astype('float32') / 255\n",
    "\n",
    "# Apply the transform to the image (first element) of every sample\n",
    "train_data = mx.gluon.data.vision.MNIST(train=True).transform_first(data_xform)\n",
    "val_data = mx.gluon.data.vision.MNIST(train=False).transform_first(data_xform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples per mini-batch, shared by both loaders\n",
    "batch_size = 100\n",
    "\n",
    "# Shuffling is enabled for the training set only\n",
    "train_loader = mx.gluon.data.DataLoader(\n",
    "    train_data,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")\n",
    "val_loader = mx.gluon.data.DataLoader(\n",
    "    val_data,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function  # only relevant for Python 2\n",
    "import mxnet as mx\n",
    "from mxnet import nd, gluon, autograd\n",
    "from mxnet.gluon import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multi-layer perceptron: flatten -> 128 (relu) -> 64 (relu) -> 10 raw scores\n",
    "net = nn.HybridSequential(prefix='MLP_')\n",
    "with net.name_scope():\n",
    "    net.add(nn.Flatten())\n",
    "    for width in (128, 64):\n",
    "        net.add(nn.Dense(width, activation='relu'))\n",
    "    # No activation on the output layer: the loss function includes softmax already\n",
    "    net.add(nn.Dense(10, activation=None))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first GPU when at least one is reported, otherwise the CPU\n",
    "if mx.context.num_gpus() > 0:\n",
    "    ctx = mx.gpu(0)\n",
    "else:\n",
    "    ctx = mx.cpu(0)\n",
    "\n",
    "# Xavier-initialize the MLP parameters on the chosen device\n",
    "net.initialize(init=mx.init.Xavier(), ctx=ctx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SGD with learning rate 0.04 over all parameters of the MLP\n",
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.04})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Softmax cross-entropy loss used for training\n",
    "loss_function = gluon.loss.SoftmaxCrossEntropyLoss()\n",
    "\n",
    "# Accuracy metric used to monitor progress\n",
    "metric = mx.metric.Accuracy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 10\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for inputs, labels in train_loader:\n",
    "        # Move the batch to the device selected in `ctx`\n",
    "        inputs, labels = inputs.as_in_context(ctx), labels.as_in_context(ctx)\n",
    "\n",
    "        # Forward pass and loss are evaluated inside `record()` so that\n",
    "        # the computational graph is available for the backward pass.\n",
    "        with autograd.record():\n",
    "            outputs = net(inputs)\n",
    "            loss = loss_function(outputs, labels)\n",
    "\n",
    "        # Backpropagate, then fold this batch into the running metric\n",
    "        loss.backward()\n",
    "        metric.update(labels, outputs)\n",
    "\n",
    "        # Step the trainer with the actual size of this batch, which\n",
    "        # normalizes the gradients by `1 / batch_size`.\n",
    "        trainer.step(inputs.shape[0])\n",
    "\n",
    "    # Report the metric for this epoch, then start the next one from scratch\n",
    "    print('After epoch {}: {} = {}'.format(epoch + 1, *metric.get()))\n",
    "    metric.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fresh metric so that only validation batches contribute to the score\n",
    "metric = mx.metric.Accuracy()\n",
    "for inputs, labels in val_loader:\n",
    "    # Possibly copy inputs and labels to the GPU\n",
    "    inputs = inputs.as_in_context(ctx)\n",
    "    labels = labels.as_in_context(ctx)\n",
    "    metric.update(labels, net(inputs))\n",
    "\n",
    "# Read the metric once and reuse it for both the report and the check\n",
    "val_name, val_acc = metric.get()\n",
    "print('Validation: {} = {}'.format(val_name, val_acc))\n",
    "assert val_acc > 0.96, 'MLP validation accuracy {} is not above 0.96'.format(val_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mislabeled(loader, model=None):\n",
    "    \"\"\"Return list of ``(input, pred_lbl, true_lbl)`` for mislabeled samples.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    loader : iterable\n",
    "        Yields ``(inputs, labels)`` batches.\n",
    "    model : callable, optional\n",
    "        Network used to compute the predictions. Defaults to the global\n",
    "        ``net`` so that existing ``get_mislabeled(loader)`` calls keep\n",
    "        working.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of tuple\n",
    "        One ``(input, pred_lbl, true_lbl)`` entry per sample whose predicted\n",
    "        label differs from the true label; ``input`` is a NumPy array and\n",
    "        both labels are Python ints.\n",
    "    \"\"\"\n",
    "    if model is None:\n",
    "        model = net\n",
    "    mislabeled = []\n",
    "    for inputs, labels in loader:\n",
    "        inputs = inputs.as_in_context(ctx)\n",
    "        labels = labels.as_in_context(ctx)\n",
    "        outputs = model(inputs)\n",
    "        # Predicted label is the index where the output is maximal\n",
    "        preds = nd.argmax(outputs, axis=1)\n",
    "        # Convert each batch to NumPy once instead of calling `asscalar()`\n",
    "        # on every single sample.\n",
    "        inputs_np = inputs.asnumpy()\n",
    "        preds_np = preds.asnumpy().astype('int64')\n",
    "        labels_np = labels.asnumpy().astype('int64')\n",
    "        for idx in (preds_np != labels_np).nonzero()[0]:\n",
    "            # Copy the image so the stored entry does not keep the whole batch alive\n",
    "            mislabeled.append((inputs_np[idx].copy(), int(preds_np[idx]), int(labels_np[idx])))\n",
    "    return mislabeled"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "sample_size = 8\n",
    "# Dedicated, seeded generator so the displayed samples are the same on every run\n",
    "sample_rng = np.random.RandomState(42)\n",
    "\n",
    "\n",
    "def sample_mislabeled(mislabeled, size):\n",
    "    \"\"\"Return up to ``size`` distinct entries drawn at random from ``mislabeled``.\n",
    "\n",
    "    Fewer than ``size`` entries are returned when ``mislabeled`` is shorter,\n",
    "    and an empty list when it is empty.\n",
    "    \"\"\"\n",
    "    count = min(size, len(mislabeled))\n",
    "    if count == 0:\n",
    "        return []\n",
    "    # replace=False: never show the same sample twice\n",
    "    picks = sample_rng.choice(len(mislabeled), size=count, replace=False)\n",
    "    return [mislabeled[i] for i in picks]\n",
    "\n",
    "\n",
    "def plot_mislabeled(samples, title):\n",
    "    \"\"\"Show the ``(image, pred_lbl, true_lbl)`` entries of ``samples`` side by side.\"\"\"\n",
    "    if not samples:\n",
    "        print(\"{}: no mislabeled samples to show\".format(title))\n",
    "        return\n",
    "    # squeeze=False keeps `axs` two-dimensional even for a single sample\n",
    "    fig, axs = plt.subplots(ncols=len(samples), squeeze=False)\n",
    "    fig.set_size_inches(18, 4)\n",
    "    fig.suptitle(title, fontsize=20)\n",
    "    for ax, (img, pred, lbl) in zip(axs[0], samples):\n",
    "        ax.imshow(img[0], cmap=\"gray\")\n",
    "        ax.set_title(\"Predicted: {}\\nActual: {}\".format(pred, lbl))\n",
    "        ax.xaxis.set_visible(False)\n",
    "        ax.yaxis.set_visible(False)\n",
    "\n",
    "\n",
    "wrong_train = get_mislabeled(train_loader)\n",
    "wrong_val = get_mislabeled(val_loader)\n",
    "wrong_train_sample = sample_mislabeled(wrong_train, sample_size)\n",
    "wrong_val_sample = sample_mislabeled(wrong_val, sample_size)\n",
    "\n",
    "plot_mislabeled(wrong_train_sample, \"Sample of wrong predictions in the training set\")\n",
    "plot_mislabeled(wrong_val_sample, \"Sample of wrong predictions in the validation set\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stand-alone convolution layer: 16 input channels, 32 output channels, 3x3 kernel\n",
    "conv_layer = nn.Conv2D(\n",
    "    channels=32,\n",
    "    kernel_size=(3, 3),\n",
    "    in_channels=16,\n",
    "    activation='relu',\n",
    ")\n",
    "\n",
    "# Show the parameters held by the layer\n",
    "print(conv_layer.params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lenet = nn.HybridSequential(prefix='LeNet_')\n",
    "with lenet.name_scope():\n",
    "    # Two stages of 5x5 tanh convolution followed by 2x2 max-pooling\n",
    "    for num_filters in (20, 50):\n",
    "        lenet.add(nn.Conv2D(channels=num_filters, kernel_size=(5, 5), activation='tanh'))\n",
    "        lenet.add(nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)))\n",
    "    # Dense head: flatten -> 500 (tanh) -> 10 outputs without activation\n",
    "    lenet.add(nn.Flatten())\n",
    "    lenet.add(nn.Dense(500, activation='tanh'))\n",
    "    lenet.add(nn.Dense(10, activation=None))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Xavier-initialize LeNet on the selected device\n",
    "lenet.initialize(init=mx.init.Xavier(), ctx=ctx)\n",
    "\n",
    "# Summarize the network using an all-zero input of shape (1, 1, 28, 28)\n",
    "dummy_batch = nd.zeros((1, 1, 28, 28), ctx=ctx)\n",
    "lenet.summary(dummy_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SGD trainer over the LeNet parameters\n",
    "trainer = gluon.Trainer(\n",
    "    params=lenet.collect_params(),\n",
    "    optimizer='sgd',\n",
    "    optimizer_params={'learning_rate': 0.04},\n",
    ")\n",
    "metric = mx.metric.Accuracy()\n",
    "num_epochs = 10\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    for inputs, labels in train_loader:\n",
    "        # Possibly copy inputs and labels to the GPU\n",
    "        inputs = inputs.as_in_context(ctx)\n",
    "        labels = labels.as_in_context(ctx)\n",
    "\n",
    "        # Record forward pass and loss for automatic differentiation\n",
    "        with autograd.record():\n",
    "            outputs = lenet(inputs)\n",
    "            loss = loss_function(outputs, labels)\n",
    "\n",
    "        loss.backward()\n",
    "        metric.update(labels, outputs)\n",
    "\n",
    "        # Step with the actual size of this batch\n",
    "        trainer.step(batch_size=inputs.shape[0])\n",
    "\n",
    "    name, acc = metric.get()\n",
    "    print('After epoch {}: {} = {}'.format(epoch + 1, name, acc))\n",
    "    metric.reset()\n",
    "\n",
    "# Use a dedicated metric for validation instead of relying on the reset\n",
    "# performed at the end of the last training epoch.\n",
    "val_metric = mx.metric.Accuracy()\n",
    "for inputs, labels in val_loader:\n",
    "    inputs = inputs.as_in_context(ctx)\n",
    "    labels = labels.as_in_context(ctx)\n",
    "    val_metric.update(labels, lenet(inputs))\n",
    "\n",
    "# Read the metric once and reuse it for both the report and the check\n",
    "val_name, val_acc = val_metric.get()\n",
    "print('Validation: {} = {}'.format(val_name, val_acc))\n",
    "assert val_acc > 0.985, 'LeNet validation accuracy {} is not above 0.985'.format(val_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
